Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RF] RooFit AD refactoring #16772

Merged
merged 2 commits into from
Nov 18, 2024

Conversation

guitargeek
Copy link
Contributor

@guitargeek guitargeek commented Oct 28, 2024

  1. Avoid referencing RooFuncWrapper inside code generation context
  2. Take out RooFit code generation context outside of Detail namespace
  3. Use dispatching via interpreter for code generation, so that the code generation methods can be defined in any place (also user frameworks)

All of this cumulates in an improved developer experience for working on the RooFit code generation, which will be valuable next week at the CMS hackathon.

Here is an example of how to override the code generation function for a class in RooFit:

// To demonstrate that even when the class is a derived class, the dispatching
// will work.
class MyGaussian : public RooGaussian {
public:
    using RooGaussian::RooGaussian;

    ClassDefOverride(MyGaussian, 0);
};

namespace RooFit {

using MyPrio = Prio<3>;

// RooFit internal functions have prio 9, so we have to use a lower number to
// override it. The prio parameter is optional and only relevant for
// overriding: if you implement codegen for a new class you don't need it.

void codegenImpl(RooGaussian &arg,
                 CodegenContext &ctx,
                 MyPrio p)
{

   std::cout << "User override" << std::endl;

   // Fall through to next implementation
   return codegenImpl(arg, ctx, p.next());
}

std::string codegenIntegralImpl(RooGaussian &arg,
                                int code,
                                const char *rangeName,
                                CodegenContext &ctx,
                                MyPrio p)
{
   std::cout << "User override (integral)" << std::endl;

   // Fall through to next implementation
   return codegenIntegralImpl(arg, code, rangeName, ctx, p.next());
}

} // namespace RooFit

void repro()
{
   RooRealVar x("x", "x", -10, 10);
   RooRealVar mean("mean", "mean of gaussian", 1, -10, 10);
   RooRealVar sigma("sigma", "width of gaussian", 1, 0.1, 10);
   MyGaussian gauss("gauss", "gaussian PDF", x, mean, sigma);

   std::unique_ptr<RooAbsReal> integ{gauss.createIntegral(x)};

   // Generate function for unnormalized Gaussian and the integral
   std::string funcName = RooFit::CodegenContext{}.buildFunction(gauss);
   std::string funcNameInteg = RooFit::CodegenContext{}.buildFunction(*integ);

   // Check out generated code
   gInterpreter->ProcessLine(funcName.c_str());
   gInterpreter->ProcessLine(funcNameInteg.c_str());
}

Output:

User override
User override (integral)
.... some code dump ....

@guitargeek guitargeek self-assigned this Oct 28, 2024
@guitargeek guitargeek force-pushed the roofit_ad_refactoring branch 4 times, most recently from 94aff8c to 423cfbf Compare October 28, 2024 22:41
Copy link

github-actions bot commented Oct 29, 2024

Test Results

    18 files      18 suites   4d 2h 32m 0s ⏱️
 2 678 tests  2 678 ✅ 0 💤 0 ❌
46 342 runs  46 342 ✅ 0 💤 0 ❌

Results for commit 63c0cb8.

♻️ This comment has been updated with latest results.

@guitargeek guitargeek force-pushed the roofit_ad_refactoring branch 7 times, most recently from 77419ea to 9457437 Compare November 13, 2024 15:13
@guitargeek guitargeek marked this pull request as ready for review November 13, 2024 22:16
@egpbos
Copy link
Contributor

egpbos commented Nov 15, 2024

Just thinking about this a bit after the Stats meeting yesterday, I was wondering: does the dispatching actually work well with class hierarchies? Or would a user need to override for every subclass separately as well? Possibly this is not a problem in practice for RooFit, if the user would mainly be interested in modifying the codegen implementation functions for "final" classes. But I was just considering it in the broader context of emulating multiple dispatch :)

Note btw that there are also existing libraries that do multiple dispatch emulation, like https://github.com/jll63/yomm2. That one also requires some boilerplate (e.g. registration), but your solution does too (the priority flags), so may be interesting to weigh those against each other. Note that I have no strong opinion either way; I'm just interested to see how this turns out, because multiple dispatch is awesome 😄

@guitargeek
Copy link
Contributor Author

Hi @egpbos, thank you very much for your comment!

I would like to not introduce further dependencies to ROOT, so let's try to implement this with the interpreter that we already have :)

does the dispatching actually work well with class hierarchies?

No, it doesn't. I have thought about this quite a bit, and realized that this is sort of hard to implement at the same time, at least how it is done now.

For example, if you have MyGaussian deriving from RooGaussian, which overload would you expect to hit?

Priority<4>, RooGaussian&
Priority<5>, MyGaussian& // higher number means lower priority

I would expect it to resolve to the second one. However, right now in my implementation, "Prio" takes priority over argument type matching, so this is unintuitive. This is hard to unify without additional boilerblate....

One solution would be to just leave how things are in this PR, which would mean that you have to write some boilerplate when deriving classes where you want to keep the code generation the same:

void codegenImpl(Prio<0> p, MyGaussian &arg, CodgenContext &ctx) {
   codegenImpl(p, static_cast<RooGaussian&>(arg), ctx);
}

The other solution is to drop the priority mechanism, but I think it is quite useful...

@guitargeek
Copy link
Contributor Author

...that one also requires some boilerplate (e.g. registration), but your solution does too (the priority flags)

Just to clarify: the priority flags are not boilerplate but a feature. They can also be dropped, but then it would not be possible anymore to "override" behavior for existing classes in usercode.

@guitargeek
Copy link
Contributor Author

Or the ambiguity I mentioned can be resolved by requiring that priority has to decrease (e.g. higher priority) with depth in the class hierarchy. This can't be enforced, but it can be documented.

@guitargeek guitargeek force-pushed the roofit_ad_refactoring branch 2 times, most recently from 95f5480 to 63c0cb8 Compare November 15, 2024 17:16
@guitargeek
Copy link
Contributor Author

I have now implemented the comments I have received by @egpbos and also @vgvassilev in private:

  1. If you create a new derived class, it will correctly dispatch to the function that takes the base class (see demo in the PR description)
  2. The Prio argument is now at the final position and optional

@guitargeek guitargeek requested review from egpbos and vgvassilev and removed request for bellenot November 15, 2024 17:27
Copy link
Member

@lmoneta lmoneta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you Jonas for this big refactoring and improvement in the code generation for AD.

LGTM!

Copy link
Member

@vgvassilev vgvassilev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a few tests checking the expected scenarios. This would help us see the bigger picture.

@egpbos
Copy link
Contributor

egpbos commented Nov 15, 2024

Ah that's right, the priority mechanism is a separate thing. Your idea about documenting "sensible" priority ordering makes sense too. Btw I think it's actually not even a thing for Julia, because there is no inheritance there.

Edit: oh weird, I only now see all your later comments. The above was a response only to the comments between 1 and 2 PM 😄

Copy link
Contributor

@egpbos egpbos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like this concept. Maybe there are some details to iron out. Perhaps consider putting this in an Experimental namespace, or is it already?

Comment on lines 134 to 153
template <class Arg_t, int P>
std::string codegenIntegralImpl(Arg_t &arg, int code, const char *rangeName, CodegenContext &ctx, Prio<P> p)
{
if constexpr (std::is_same<Prio<P>, PrioLowest>::value) {
return codegenIntegralImpl(arg, code, rangeName, ctx);
} else {
return codegenIntegralImpl(arg, code, rangeName, ctx, p.next());
}
}

template <class Arg_t>
struct CodegenIntegralImplCaller {

static auto call(RooAbsReal &arg, int code, const char *rangeName, RooFit::CodegenContext &ctx)
{
return codegenIntegralImpl(static_cast<Arg_t &>(arg), code, rangeName, ctx, PrioHighest{});
}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that the Caller first tries to call the prio 1 (highest) implementation, and then will loop through all 10 using the Impl template until it finds one that is actually implemented for the Arg_t already so doesn't need to use the template?

Few questions:

  1. Do you actually need all 10 priority levels? How about replacing that with just an overload boolean? It will save some looping through empty functions.
  2. I'm not sure, but there may be some pathological situations. For instance, what if in an interactive session you call some function, the tree of templates gets instantiated and then afterwards you try to define a new overload? Won't that overload then already have been instantiated from the Impl template? Another possible situation would be something similar to the above but in separate compilation units.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great comments!

  1. I was hoping the optimizer gets rid of these loop levels, which are "statically" clear to the interpreter. So in principle they should be optimized out. But it's an interesting question: I'll benchmark how expensive the levels are
  2. You're right, this is a clear limitation, but design: the function pointer that was resolved is cached in an unordered_map, so speed up the lookup for the same class later. I'll check what would happen if the caching is disabled, that's an interesting question!

It was a bit awkward that the RooFuncWrapper instantiated a code
generation context, which itself had to use the RooFuncWrapper via a
reference.

This commit suggests to refactor the code such that the context doesn't
need to know anything about the RooFuncWrapper.
@guitargeek guitargeek force-pushed the roofit_ad_refactoring branch from 63c0cb8 to 2d48ca9 Compare November 18, 2024 19:29
Use dispatching via interpreter for code generation, so that the code
generation methods can be defined in any place (also user frameworks).

This also includes a priority mechanism for easy overriding.

We are expecting users to interact with this class, since it is
essential to implement code generation for your own RooFit primitives.

Putting the class in the "Detail" namespace sends the wrong message and
normalized it for users to ignore the convention that "Detail" or
"Internal" stuff should not be used. Therefore, it was moved into the
"Experimental" namespace.

While removing the namespace, also rename the class to a shorter and
more succinct name, matching the "codegen" evaluation backend name.
"Codegen" also sounds more friendly than "CodeSquash", which reminds me
of "fly squash".
@guitargeek guitargeek force-pushed the roofit_ad_refactoring branch from 2d48ca9 to 5dc0353 Compare November 18, 2024 19:36
Copy link
Member

@vgvassilev vgvassilev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me as a way forward here to enable early adoption in other clients such as Combine.

We should iterate on generalizing the coden dispatcher, into a generic dispatcher of the signature: dispatch(<RooClass>, <DispatchedToFnName>, <Context>).

@guitargeek guitargeek merged commit 96e023b into root-project:master Nov 18, 2024
18 of 20 checks passed
@guitargeek guitargeek deleted the roofit_ad_refactoring branch November 18, 2024 21:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants